#!/usr/bin/env python3

# Simple utility for programming SPI flash via a UART shell in the bootloader.
# The bootloader possesses a 1-page buffer which we can write to, read from and
# checksum.
# It also possesses commands for transfers between this buffer and the SPI flash,
# flash erasure, and launching the 2nd stage boot code.

import argparse
import os
import serial
import serial.tools.list_ports
import sys
import time

# Sizes in bytes. PAGESIZE is the unit of every buffer transfer below;
# SECTORSIZE and BLOCKSIZE are the two step sizes used when erasing.
PAGESIZE   = 1 << 8
SECTORSIZE = 1 << 12
BLOCKSIZE  = 1 << 16

# Base values for the addresses sent before the load-memory and
# execute-memory commands; the entry point sits 0xc0 bytes past the load base.
SRAM_LOAD_ADDR = 0x20000000
SRAM_EXEC_ADDR = SRAM_LOAD_ADDR + 0xc0

class ProtocolError(Exception):
	"""Raised when the bootloader's reply is not what the shell protocol expects."""

class CMD:
	"""Single-byte codes exchanged with the bootloader's UART shell.

	Every member except ACK is written to the port as a command; ACK is the
	byte check_ack() expects to read back.
	"""
	NOP          = b'\n'
	WRITE_BUF    = b'w'
	READ_BUF     = b'r'
	GET_CHECKSUM = b'c'
	SET_ADDR     = b'a'
	WRITE_FLASH  = b'W'
	READ_FLASH   = b'R'
	ERASE_SECTOR = b'E'
	ERASE_BLOCK  = b'B'
	BOOT_2ND     = b'2'
	LOAD_MEM     = b'l'
	EXEC_MEM     = b'x'
	ACK          = b':'

def reset_shell(port):
	"""Block until the bootloader shell answers a NOP with an ACK.

	Each attempt sends PAGESIZE + 1 NOPs, clears both serial buffers, sends
	one more NOP, and then checks that whatever has arrived ends with
	CMD.ACK. The loop retries indefinitely; there is no overall timeout.

	port: an open pyserial port.
	"""
	while True:
		# One more NOP than a page holds -- presumably enough to flush out a
		# buffer write that a previous session left half finished.
		port.write(CMD.NOP * (PAGESIZE + 1))
		# reset_output_buffer()/reset_input_buffer() are the current names of
		# the deprecated flushOutput()/flushInput() aliases.
		# NOTE(review): pyserial documents reset_output_buffer() as discarding
		# unsent data rather than waiting for it to drain; if draining was the
		# intent, port.flush() is the call to use -- confirm on hardware.
		port.reset_output_buffer()
		time.sleep(0.02)
		# Throw away any replies to the NOP burst so only the probe counts.
		port.reset_input_buffer()
		port.write(CMD.NOP)
		time.sleep(0.02)
		if port.readable() and port.read_all().endswith(CMD.ACK):
			break

def check_ack(port):
	"""Read one reply from the port and require it to be CMD.ACK.

	Anything else, including an empty read, is printed and reported by
	raising ProtocolError.
	"""
	reply = port.read()
	if reply == CMD.ACK:
		return
	# An empty reply also lands here, since it cannot equal CMD.ACK.
	print("Got '{!r}'".format(reply))
	raise ProtocolError()

def write_buf(port, data, check=True):
	"""Send one page of data to the bootloader's buffer.

	port:  open serial port.
	data:  exactly PAGESIZE bytes.
	check: when true, wait for the ACK here. Pass False when the caller
	       sends a further command first and collects the ACKs itself.

	Raises AssertionError if data is not one page long, and ProtocolError
	(via check_ack) if check is true and no ACK arrives.
	"""
	# Compare against PAGESIZE rather than a literal so this stays in step
	# with read_buf() and write_flash(), which size pages from the constant.
	assert len(data) == PAGESIZE
	port.write(CMD.WRITE_BUF)
	port.write(data)
	if check:
		check_ack(port)
	# TODO checksum

def read_buf(port, pipelined=False):
	"""Read one page back from the bootloader's buffer.

	pipelined: set when the caller has already written CMD.READ_BUF, in
	           which case only the reply is collected here.

	Returns PAGESIZE bytes; raises ProtocolError on a short read.
	"""
	# TODO checksum
	if not pipelined:
		port.write(CMD.READ_BUF)
	page = port.read(PAGESIZE)
	if len(page) == PAGESIZE:
		return page
	raise ProtocolError()

def set_addr(port, addr):
	"""Send CMD.SET_ADDR with a 3-byte address and validate the reply.

	Only the low 24 bits of addr are transmitted, most significant byte
	first. The bootloader is expected to echo the same three bytes and then
	send an ACK; a wrong echo or a missing ACK raises ProtocolError.
	"""
	# Mask first: callers may pass values wider than 24 bits.
	encoded = (addr & 0xffffff).to_bytes(3, "big")
	port.write(CMD.SET_ADDR + encoded)
	if port.read(3) != encoded:
		raise ProtocolError()
	check_ack(port)

def progress(header, frac, width=40):
	"""Redraw a one-line progress bar on stdout.

	header: text printed before the bar.
	frac:   completed fraction; the bar fills int(width * frac) cells.
	width:  total number of cells in the bar.

	The line starts with a carriage return and has no newline, so successive
	calls overwrite each other.
	"""
	filled = int(width * frac)
	bar = "▒" * filled + " " * (width - filled)
	sys.stdout.write("".join(("\r", header, "▕", bar, "▏")))

def read_flash(port, addr, size):
	"""Read size bytes of flash starting at addr and return them as bytes.

	The address is sent once with set_addr(). Each iteration then issues
	CMD.READ_FLASH followed by CMD.READ_BUF, waits for one ACK and collects
	one PAGESIZE-byte page. Whole pages are always transferred, so the
	result is trimmed to size at the end.

	Raises ProtocolError if an ACK or a complete page fails to arrive.
	"""
	set_addr(port, addr)
	# Collect pages in a list and join once; growing a bytes object with +=
	# re-copies everything read so far on every page.
	pages = []
	# No address is re-sent per page -- presumably the bootloader advances
	# its own address after each READ_FLASH (not visible from this side).
	for page_addr in range(addr, addr + size, PAGESIZE):
		port.write(CMD.READ_FLASH)
		port.write(CMD.READ_BUF) # Should have space for 2 cmds in FIFO. Hoist this one up from read_buf()
		progress("Read:  ", (page_addr - addr) / size)
		check_ack(port)
		pages.append(read_buf(port, pipelined=True))
	progress("Read:  ", 1)
	# The last page may overshoot the requested length.
	return b"".join(pages)[:size]

def write_flash(port, addr, data):
	"""Program data into flash starting at the page-aligned address addr.

	data is zero-padded up to a whole number of pages. The address is sent
	once; each page is then written to the bootloader's buffer and committed
	with CMD.WRITE_FLASH. Nothing is erased here.

	Raises AssertionError if addr is not page aligned, and ProtocolError if
	an expected ACK does not arrive.
	"""
	assert(addr % PAGESIZE == 0)
	remainder = len(data) % PAGESIZE
	if remainder != 0:
		data += bytes(PAGESIZE - remainder)
	set_addr(port, addr)
	n_pages = len(data) // PAGESIZE
	for index in range(n_pages):
		offset = index * PAGESIZE
		# Send the buffer write and the flash write back to back, then
		# collect both ACKs afterwards.
		write_buf(port, data[offset : offset + PAGESIZE], check=False)
		port.write(CMD.WRITE_FLASH)
		progress("Write: ", index / n_pages)
		check_ack(port)
		check_ack(port)
	progress("Write: ", 1)

def erase_flash(port, addr, size):
	"""Erase every sector touched by the range [addr, addr + size).

	The range is widened outwards to sector boundaries. Wherever the current
	position is block aligned and at least a whole block remains,
	CMD.ERASE_BLOCK is used; otherwise CMD.ERASE_SECTOR. The address is sent
	once up front -- presumably the bootloader advances it after each erase.

	Raises ProtocolError if an erase is not acknowledged.
	"""
	# Round the start down and the end up to sector boundaries.
	start = (addr // SECTORSIZE) * SECTORSIZE
	end = addr + size
	end += -end % SECTORSIZE
	set_addr(port, start)
	span = end - start
	pos = start
	while pos < end:
		progress("Erase: ", (pos - start) / span)
		if pos % BLOCKSIZE == 0 and end - pos >= BLOCKSIZE:
			command, step = CMD.ERASE_BLOCK, BLOCKSIZE
		else:
			command, step = CMD.ERASE_SECTOR, SECTORSIZE
		port.write(command)
		pos += step
		check_ack(port)
	progress("Erase: ", 1)

def run_2nd_stage(port):
	"""Send CMD.BOOT_2ND to the bootloader and wait for its ACK.

	The address 0x123456 is sent first; NOTE(review): its meaning is not
	visible in this file (possibly a guard value) -- confirm against the
	bootloader source.
	"""
	boot_addr = 0x123456
	set_addr(port, boot_addr)
	port.write(CMD.BOOT_2ND)
	check_ack(port)

def load_mem(port, data):
	"""Upload data to the device with CMD.LOAD_MEM.

	Sequence: set the address, send CMD.LOAD_MEM followed by the low 24 bits
	of len(data) (most significant byte first), wait for an ACK, send the
	raw payload in one write, and wait for a second ACK.

	Raises ProtocolError if either ACK is missing.
	"""
	set_addr(port, 0xb00000 | SRAM_LOAD_ADDR)
	length_field = (len(data) & 0xffffff).to_bytes(3, "big")
	port.write(CMD.LOAD_MEM + length_field)
	check_ack(port)
	port.write(data) # yup
	check_ack(port)

def exec_mem(port):
	"""Set the address derived from SRAM_EXEC_ADDR, send CMD.EXEC_MEM and wait for an ACK.

	Raises ProtocolError if the address echo or the ACK is wrong.
	"""
	entry = 0xb00000 | SRAM_EXEC_ADDR
	set_addr(port, entry)
	port.write(CMD.EXEC_MEM)
	check_ack(port)

def any_int(x):
	"""Parse an integer string whose base is chosen by its prefix.

	With base 0, int() accepts 0x (hex), 0o (octal), 0b (binary) or plain
	decimal. Used as an argparse type so addresses can be given in hex.
	"""
	return int(x, base=0)

if __name__ == "__main__":
	# Command-line front end: parse and sanity-check the arguments, announce
	# the operation, then carry it out over the chosen UART.
	parser = argparse.ArgumentParser()
	parser.add_argument("file", nargs="?", help="Filename to read from or dump to.")
	parser.add_argument("--uart", "-u", help="Path to UART device (e.g. /dev/ttyUSB0)")
	# Parsing the baud rate here turns a typo into a usage error instead of
	# a traceback when the port is opened.
	parser.add_argument("--baud", "-b", type=int, default=1000000,
		help="Baud rate for UART. Default 1 Mbaud")
	parser.add_argument("--write", "-w", action="store_true",
		help="Write to flash (default is read)")
	parser.add_argument("--verify", "-v", action="store_true",
		help="Verify contents after programming (use with --write)")
	parser.add_argument("--run", "-r", action="store_true",
		help="Run code from flash after programming")
	parser.add_argument("--execute", "-x", action="store_true",
		help="Load code directly into SRAM and execute it. Not to be combined with -w,-v,-r")
	parser.add_argument("--start", "-s", type=any_int, help="Base address to start read/write")
	parser.add_argument("--len", "-l", type=any_int, help="Number of bytes to read, or override file length for write.")
	args = parser.parse_args()

	# Do some simple parameter checking to make it harder to accidentally trash your device
	if args.write and args.file is None:
		sys.exit("Must specify filename for write")
	if args.start is None:
		args.start = 0
	if args.uart is None:
		# Default to the last detected port in sorted order, but fail with a
		# message rather than an IndexError when none is present.
		candidates = sorted(p.device for p in serial.tools.list_ports.comports())
		if not candidates:
			sys.exit("No serial ports found; specify one with --uart")
		args.uart = candidates[-1]
	if args.verify and not args.write:
		sys.exit("Verify is only valid for a write operation")
	if args.run and not args.write:
		sys.exit("Run is only valid for a write operation")
	if args.write and (args.start % PAGESIZE != 0):
		sys.exit("Writes must be aligned on a {}-byte boundary.".format(PAGESIZE))
	if args.execute and (args.write or args.verify or args.run or args.start or args.len):
		sys.exit("--execute is not compatible with flash-related arguments")
	if args.execute and args.file is None:
		sys.exit("Must specify filename for execute")
	if args.write or args.execute:
		# Only filesystem errors are expected here; anything else should
		# surface as a real traceback.
		try:
			filesize = os.stat(args.file).st_size
		except OSError:
			sys.exit("Could not open file '{}'".format(args.file))
		if args.len is None:
			args.len = filesize
	elif args.len is None:
		args.len = PAGESIZE

	# Don't want to overwrite their image if they forget -w
	if not (args.write or args.execute) and args.file and os.path.exists(args.file):
		resp = ""
		while resp not in ("y", "n"):
			resp = input("File '{}' exists. Overwrite? (y/n) ".format(args.file))
			resp = resp.lower().strip()
		if resp == "n":
			sys.exit(0)

	# Summarise what we're about to do
	if args.execute:
		print("Loading {} bytes to SRAM and running".format(args.len))
	else:
		if args.write:
			action = "Writing" + (" and verifying" if args.verify else "")
		else:
			action = "Reading"
		print("{} {} bytes {} {}, starting at address 0x{:06x}".format(
			action,
			args.len,
			"from" if args.write else "to",
			"stdout" if args.file is None else "file " + args.file,
			args.start
		))

	# And then do it
	# need to allow for page erase, upward of 60 ms. (Uh, actually it seems to be much longer)
	# The with-block closes the port on every exit path, including sys.exit().
	with serial.Serial(args.uart, args.baud, timeout=1) as port:
		print("Waiting for bootloader on {}...".format(port.name))
		reset_shell(port)
		print("")

		if args.execute:
			print("Loading...")
			with open(args.file, "rb") as image_file:
				image = image_file.read()
			load_mem(port, image)
			print("Load ok, running...")
			exec_mem(port)
		elif args.write:
			with open(args.file, "rb") as image_file:
				data = image_file.read()
			# Make the image exactly args.len bytes long.
			if len(data) > args.len:
				data = data[:args.len]
			else:
				data = data + bytes(args.len - len(data)) # zero padding

			start = time.time()
			erase_flash(port, args.start, args.len)
			print(" Took {:.1f} s\n".format(time.time() - start))
			start = time.time()
			write_flash(port, args.start, data)
			print(" Took {:.1f} s\n".format(time.time() - start))
			if args.verify:
				start = time.time()
				readback = read_flash(port, args.start, args.len)
				print(" Took {:.1f} s".format(time.time() - start))
				if readback != data:
					sys.exit("Verification failed.") # TODO: better info
			if args.run:
				print("Launching flash second stage")
				run_2nd_stage(port)
			print("Done")
		else:
			start = time.time()
			data = read_flash(port, args.start, args.len)
			print(" Took {:.1f} s\nDone".format(time.time() - start))
			if args.file:
				with open(args.file, "wb") as dump_file:
					dump_file.write(data)
			else:
				# Hex dump, eight bytes per line, each line prefixed with its address.
				for i, byte in enumerate(data):
					if i % 8 == 0:
						sys.stdout.write("{:06x}: ".format(args.start + i))
					sys.stdout.write("{:02x}".format(byte))
					sys.stdout.write("\n" if i % 8 == 7 else " ")
				sys.stdout.write("\n")
